"""
Probe life expectancy turning point and inflation channel dominance.

Investigates:
1. Is the 41.8 LE turning point driven by developing countries?
2. Does it return to ~60+ for OECD?
3. Is the inflation channel dominance universal or developing-country-specific?
"""

import sys
from pathlib import Path
import numpy as np
import pandas as pd

PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/japanification")
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# ---------------------------------------------------------------------------
# Load data
# ---------------------------------------------------------------------------
# Read the processed panel and print a short summary of what it contains.
panel_csv = PROJECT_DIR / "data" / "processed" / "japan_panel_indexed.csv"
df = pd.read_csv(panel_csv)

banner = "=" * 80
print(banner)
print("COLUMNS:", list(df.columns))
print(f"Shape: {df.shape}")
n_countries = df["iso3"].nunique()
first_year, last_year = df["year"].min(), df["year"].max()
print(f"Countries: {n_countries}, Years: {first_year}-{last_year}")
le_min, le_max = df["life_expectancy"].min(), df["life_expectancy"].max()
print(f"LE range: {le_min:.1f} - {le_max:.1f}")
# Coverage of the key columns used by the regressions further down.
for col in ("japan_index_2c", "z_growth", "z_inflation", "gdp_pc_ppp"):
    print(f"{col} non-null: {df[col].notna().sum()}")
print(banner)

# ISO3 codes used to flag OECD members in the panel.
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]

# Derived columns: 0/1 OECD membership flag and squared life expectancy.
is_oecd = df["iso3"].isin(OECD_38)
df["oecd"] = is_oecd.astype(int)
df["le_sq"] = df["life_expectancy"].pow(2)

# Control regressors appended to every specification in this script.
CONTROLS = ["fiscal_bal_gdp", "kaopen", "log_rel_opw", "nfa_gdp_lag"]


def run_le_quadratic(data, dep_var, label):
    """Regress ``dep_var`` on life expectancy, its square and the controls.

    Rows with a missing value in the dependent variable, any regressor,
    ``iso3`` or ``year`` are dropped before fitting.

    Args:
        data: Panel frame holding ``life_expectancy``, ``le_sq``, the
            ``CONTROLS`` columns, ``iso3``, ``year`` and ``dep_var``.
        dep_var: Name of the dependent-variable column.
        label: Free-text tag copied into the result for display.

    Returns:
        ``None`` when fewer than 30 complete rows remain. Otherwise a dict
        with the label, the fit's ``n_obs`` / ``n_countries`` / ``r_squared``,
        the coefficient, standard error and p-value of the linear and squared
        life-expectancy terms, and ``turning_point = -b_LE / (2 * b_LE2)``
        (NaN when the squared coefficient is exactly zero).
    """
    regressors = ["life_expectancy", "le_sq"] + CONTROLS
    required = [dep_var] + regressors + ["iso3", "year"]
    sample = data.dropna(subset=required).copy()
    if len(sample) < 30:
        return None

    model = PanelGLS()
    model.fit(
        sample[dep_var].values,
        sample[regressors].values,
        sample["iso3"].values,
        sample["year"].values,
    )

    # Positions 0 and 1 are read as the LE and LE^2 terms, i.e. this assumes
    # PanelGLS reports coefficients in regressor-column order -- TODO confirm.
    b_linear, b_square = model.beta[0], model.beta[1]
    if b_square != 0:
        vertex = -b_linear / (2 * b_square)
    else:
        vertex = np.nan

    summary = {
        "label": label,
        "n_obs": model.n_obs,
        "n_countries": model.n_countries,
        "r_squared": model.r_squared,
    }
    summary["beta_le"] = b_linear
    summary["se_le"] = model.se[0]
    summary["p_le"] = model.pvalues[0]
    summary["beta_le2"] = b_square
    summary["se_le2"] = model.se[1]
    summary["p_le2"] = model.pvalues[1]
    summary["turning_point"] = vertex
    return summary


def run_z_channel(data, dep_var, label):
    """Regress ``dep_var`` on ``Z_1``, ``Z_2``, ``Z_3`` and the controls.

    Rows with a missing value in the dependent variable, any regressor,
    ``iso3`` or ``year`` are dropped before fitting.

    Args:
        data: Panel frame holding the ``Z_*`` columns, the ``CONTROLS``
            columns, ``iso3``, ``year`` and ``dep_var``.
        dep_var: Name of the dependent-variable column.
        label: Free-text tag copied into the result for display.

    Returns:
        ``None`` when fewer than 30 complete rows remain. Otherwise a dict
        with the label, the fit's ``n_obs`` / ``n_countries`` / ``r_squared``
        and, for every regressor ``v``, the entries ``beta_v``, ``se_v`` and
        ``p_v``.
    """
    regressors = ["Z_1", "Z_2", "Z_3"] + CONTROLS
    required = [dep_var] + regressors + ["iso3", "year"]
    sample = data.dropna(subset=required).copy()
    if len(sample) < 30:
        return None

    model = PanelGLS()
    model.fit(
        sample[dep_var].values,
        sample[regressors].values,
        sample["iso3"].values,
        sample["year"].values,
    )

    summary = {
        "label": label,
        "n_obs": model.n_obs,
        "n_countries": model.n_countries,
        "r_squared": model.r_squared,
    }
    # Estimates are looked up by regressor position, i.e. this assumes
    # PanelGLS reports them in regressor-column order -- TODO confirm.
    for pos, name in enumerate(regressors):
        summary.update(
            {
                f"beta_{name}": model.beta[pos],
                f"se_{name}": model.se[pos],
                f"p_{name}": model.pvalues[pos],
            }
        )
    return summary


def stars(p):
    """Return the significance marker for p-value ``p``.

    ``"***"`` below 0.001, ``"**"`` below 0.01, ``"*"`` below 0.05, ``"+"``
    below 0.1 and an empty string otherwise. All comparisons are strict, so
    a value exactly on a cutoff falls into the weaker band.
    """
    cutoffs = ((0.001, "***"), (0.01, "**"), (0.05, "*"), (0.1, "+"))
    for limit, marker in cutoffs:
        if p < limit:
            return marker
    return ""


# ===================================================================
# 1-3. LE Quadratic: Full, OECD, Non-OECD
# ===================================================================
print("\n" + "=" * 80)
print("SECTION 1-3: LE QUADRATIC -- FULL, OECD, NON-OECD")
print("=" * 80)

# Fit the same LE quadratic on three samples of the panel. Each entry pairs a
# display label with the rows to estimate on; the full sample is the frame
# itself, so no row mask is needed for it.
results_le = []
for label, sample in [
    ("Full sample", df),
    ("OECD only", df[df["oecd"] == 1]),
    ("Non-OECD", df[df["oecd"] == 0]),
]:
    r = run_le_quadratic(sample, "japan_index_2c", label)
    if r is None:
        # None signals that too few complete rows were left to estimate.
        print(f"\n--- {label} --- INSUFFICIENT DATA")
        continue
    # Only successful fits are kept; they feed the markdown table later on.
    results_le.append(r)
    print(f"\n--- {label} ---")
    print(f"  N={r['n_obs']}, Countries={r['n_countries']}, R2={r['r_squared']:.3f}")
    print(f"  LE:  b={r['beta_le']:.4f} (se={r['se_le']:.4f}, p={r['p_le']:.4f}){stars(r['p_le'])}")
    print(f"  LE2: b={r['beta_le2']:.6f} (se={r['se_le2']:.6f}, p={r['p_le2']:.4f}){stars(r['p_le2'])}")
    print(f"  >>> Turning point: {r['turning_point']:.1f} years")

# ===================================================================
# 4. LE Quadratic by income quartile
# ===================================================================
print("\n" + "=" * 80)
print("SECTION 4: LE QUADRATIC BY INCOME QUARTILE")
print("=" * 80)

results_inc = []
quartile_names = ["Q1 (poorest)", "Q2", "Q3", "Q4 (richest)"]
n_gdp_values = df["gdp_pc_ppp"].notna().sum()
if n_gdp_values <= 100:
    # Too few GDP observations to form meaningful income groups.
    print("gdp_pc_ppp not available -- skipping income quartile analysis")
else:
    # Countries are binned once, by their median GDP per capita over the
    # panel, so every row of a country carries the same quartile.
    country_gdp = df.groupby("iso3")["gdp_pc_ppp"].median()
    quartile_labels = pd.qcut(country_gdp, 4, labels=quartile_names)
    df["income_quartile"] = df["iso3"].map(quartile_labels)

    for q in quartile_names:
        in_quartile = df["income_quartile"] == q
        r = run_le_quadratic(df[in_quartile], "japan_index_2c", f"Income {q}")
        if not r:
            print(f"\n--- Income {q} --- INSUFFICIENT DATA")
            continue
        results_inc.append(r)
        # Observed LE span of the quartile, shown next to its turning point.
        le_values = df.loc[in_quartile, "life_expectancy"].dropna()
        le_low, le_high = le_values.min(), le_values.max()
        print(f"\n--- Income {q} (LE range: {le_low:.1f}-{le_high:.1f}) ---")
        print(f"  N={r['n_obs']}, Countries={r['n_countries']}, R2={r['r_squared']:.3f}")
        print(f"  LE:  b={r['beta_le']:.4f} (se={r['se_le']:.4f}, p={r['p_le']:.4f}){stars(r['p_le'])}")
        print(f"  LE2: b={r['beta_le2']:.6f} (se={r['se_le2']:.6f}, p={r['p_le2']:.4f}){stars(r['p_le2'])}")
        print(f"  >>> Turning point: {r['turning_point']:.1f} years")

# ===================================================================
# 5. Channel decomposition: OECD vs Non-OECD
# ===================================================================
print("\n" + "=" * 80)
print("SECTION 5: CHANNEL DECOMPOSITION -- OECD vs NON-OECD")
print("=" * 80)

# Only channel columns actually present in the panel are analysed.
channel_vars = [c for c in ["z_growth", "z_inflation", "z_rate"] if c in df.columns]
print(f"Channel variables found: {channel_vars}")

# Each entry pairs a display label with the rows to estimate on; the full
# sample is the frame itself, so no row mask is needed for it.
channel_samples = [
    ("OECD", df[df["oecd"] == 1]),
    ("Non-OECD", df[df["oecd"] == 0]),
    ("Full", df),
]

results_channel = []
for dep_var in channel_vars:
    for label, sample in channel_samples:
        r = run_z_channel(sample, dep_var, f"{dep_var} | {label}")
        if r is None:
            # None signals that too few complete rows were left to estimate.
            print(f"\n--- {dep_var} | {label} --- INSUFFICIENT DATA")
            continue
        # Only successful fits are kept; they feed the markdown table later on.
        results_channel.append(r)
        print(f"\n--- {dep_var} | {label} ---")
        print(f"  N={r['n_obs']}, Countries={r['n_countries']}, R2={r['r_squared']:.3f}")
        for zv in ["Z_1", "Z_2", "Z_3"]:
            b = r[f"beta_{zv}"]
            s = r[f"se_{zv}"]
            p = r[f"p_{zv}"]
            print(f"  {zv}: b={b:.3f} (se={s:.3f}, p={p:.4f}){stars(p)}")

# ===================================================================
# 6. Channel decomposition pre/post GFC
# ===================================================================
print("\n" + "=" * 80)
print("SECTION 6: CHANNEL DECOMPOSITION -- PRE/POST GFC")
print("=" * 80)

# The panel is split at the 2007/2008 boundary; both halves are fitted for
# every channel variable.
gfc_periods = [
    ("Pre-GFC (<=2007)", df["year"] <= 2007),
    ("Post-GFC (>=2008)", df["year"] >= 2008),
]

results_gfc = []
for dep_var in channel_vars:
    for label, mask in gfc_periods:
        tag = f"{dep_var} | {label}"
        r = run_z_channel(df[mask], dep_var, tag)
        if not r:
            print(f"\n--- {tag} --- INSUFFICIENT DATA")
            continue
        results_gfc.append(r)
        print(f"\n--- {tag} ---")
        print(f"  N={r['n_obs']}, Countries={r['n_countries']}, R2={r['r_squared']:.3f}")
        for zv in ("Z_1", "Z_2", "Z_3"):
            coef, stderr, pval = r[f"beta_{zv}"], r[f"se_{zv}"], r[f"p_{zv}"]
            print(f"  {zv}: b={coef:.3f} (se={stderr:.3f}, p={pval:.4f}){stars(pval)}")

# ===================================================================
# Bonus: LE quadratic on each channel (full sample)
# ===================================================================
print("\n" + "=" * 80)
print("BONUS: LE QUADRATIC ON EACH CHANNEL (full sample)")
print("=" * 80)

# Re-run the LE quadratic with each channel variable as the dependent
# variable, always on the full panel.
results_le_channel = []
for dep_var in channel_vars:
    r = run_le_quadratic(df, dep_var, f"LE quad -> {dep_var}")
    if r is None:
        # Report skipped channels explicitly, as the other sections do,
        # instead of dropping them from the output without a trace.
        print(f"\n--- {dep_var} --- INSUFFICIENT DATA")
        continue
    results_le_channel.append(r)
    print(f"\n--- {dep_var} ---")
    print(f"  N={r['n_obs']}, R2={r['r_squared']:.3f}")
    print(f"  LE:  b={r['beta_le']:.4f} (p={r['p_le']:.4f}){stars(r['p_le'])}")
    print(f"  LE2: b={r['beta_le2']:.6f} (p={r['p_le2']:.4f}){stars(r['p_le2'])}")
    print(f"  >>> Turning point: {r['turning_point']:.1f} years")

# ===================================================================
# Save markdown
# ===================================================================
def _le_table_row(r, include_countries=True):
    """Format one LE-quadratic result dict as a markdown table row.

    ``include_countries`` controls whether the country-count cell is emitted;
    the per-channel table omits that column.
    """
    countries_cell = f"| {r['n_countries']} " if include_countries else ""
    return (
        f"| {r['label']} | {r['n_obs']} {countries_cell}| {r['r_squared']:.3f} "
        f"| {r['beta_le']:.4f}{stars(r['p_le'])} | {r['p_le']:.4f} "
        f"| {r['beta_le2']:.6f}{stars(r['p_le2'])} | {r['p_le2']:.4f} "
        f"| **{r['turning_point']:.1f}** |"
    )


def _channel_table_row(r):
    """Format one Z-channel result dict as a markdown table row.

    The result label has the form ``"<channel> | <sample>"``; its two parts
    fill the first two cells of the row.
    """
    parts = r["label"].split(" | ")
    z_cells = []
    for zv in ("Z_1", "Z_2", "Z_3"):
        beta = r[f"beta_{zv}"]
        pval = r[f"p_{zv}"]
        z_cells.append(f"| {beta:.3f}{stars(pval)} | {pval:.4f} ")
    return (
        f"| {parts[0]} | {parts[1]} "
        f"| {r['n_obs']} | {r['r_squared']:.3f} "
        + "".join(z_cells)
        + "|"
    )


output_dir = PROJECT_DIR / "output" / "tables"
output_dir.mkdir(parents=True, exist_ok=True)

md = []
md.append("# Life Expectancy Turning Point Probe\n")
md.append("**Key question**: Is the ~41.8 LE turning point driven by developing countries?")
md.append("Does it return to ~60+ for OECD? Is the inflation channel dominance universal?\n")

md.append("## 1. LE Quadratic by Sample\n")
md.append("| Sample | N | Countries | R2 | b(LE) | p(LE) | b(LE2) | p(LE2) | Turning Pt |")
md.append("|--------|---|-----------|-----|-------|-------|--------|--------|------------|")
md.extend(_le_table_row(r) for r in results_le)

# The quartile table is only written when the quartile analysis produced fits.
if results_inc:
    md.append("\n## 2. LE Quadratic by Income Quartile\n")
    md.append("| Quartile | N | Countries | R2 | b(LE) | p(LE) | b(LE2) | p(LE2) | Turning Pt |")
    md.append("|----------|---|-----------|-----|-------|-------|--------|--------|------------|")
    md.extend(_le_table_row(r) for r in results_inc)

md.append("\n## 3. Channel Decomposition: Z -> Component (OECD vs Non-OECD)\n")
md.append("| Channel | Sample | N | R2 | Z1 (b) | Z1 (p) | Z2 (b) | Z2 (p) | Z3 (b) | Z3 (p) |")
md.append("|---------|--------|---|-----|---------|---------|---------|---------|---------|---------|")
md.extend(_channel_table_row(r) for r in results_channel)

md.append("\n## 4. Channel Decomposition: Pre vs Post GFC\n")
md.append("| Channel | Period | N | R2 | Z1 (b) | Z1 (p) | Z2 (b) | Z2 (p) | Z3 (b) | Z3 (p) |")
md.append("|---------|--------|---|-----|---------|---------|---------|---------|---------|---------|")
md.extend(_channel_table_row(r) for r in results_gfc)

# The per-channel LE table has no country-count column.
if results_le_channel:
    md.append("\n## 5. LE Quadratic on Each Channel (Full Sample)\n")
    md.append("| Channel | N | R2 | b(LE) | p(LE) | b(LE2) | p(LE2) | Turning Pt |")
    md.append("|---------|---|-----|-------|-------|--------|--------|------------|")
    md.extend(_le_table_row(r, include_countries=False) for r in results_le_channel)

md.append("\n## Interpretation\n")
md.append("*(To be filled based on results above)*\n")

md_text = "\n".join(md)
outpath = output_dir / "le_turning_point_probe.md"
# Pin the encoding so the report does not depend on the platform's locale.
outpath.write_text(md_text, encoding="utf-8")
print(f"\n\nSaved to: {outpath}")
print("Done.")
